package cn.zhxu.toys.msg;

import cn.zhxu.toys.cache.CacheService;
import cn.zhxu.toys.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * A {@link MsgSender} that applies rate limiting before delegating to another sender.
 *
 * <p>For every send, a KEY is derived from the phone number by the configured
 * {@link KeyResolver}. Per KEY and per message template, two limits are enforced
 * using counters kept in the {@link CacheService}: a minimum interval between two
 * sends, and a maximum number of sends within one limiting period.
 *
 * <p>NOTE(review): the cache entry is read, modified and written back without any
 * locking in this class; whether concurrent sends for the same KEY are serialized
 * elsewhere is not visible here.
 *
 * @author Troy.Zhou @ 2022/8/17
 */
public class RateLimitMsgSender implements MsgSender {

	private static final Logger log = LoggerFactory.getLogger(RateLimitMsgSender.class);

	/** Rate-limit tag; used as the leading text of the pass/reject log lines. */
	private final String tag;

	/** The real message sender that performs the delivery. */
	private MsgSender msgSender;

	/** KEY resolver; by default the phone number itself is the KEY. */
	private KeyResolver keyResolver = phone -> phone;

	/** KEY white list; KEYs contained here are sent without any limiting. */
	private List<String> whiteList = Collections.emptyList();

	/** Length of one limiting period, in seconds (default: 12 hours). */
	private long unitSeconds = 12 * 3600;

	/** Maximum number of sends per KEY and per template within one limiting period. */
	private int maxAllowPerUnit = 30;

	/** Minimum interval, in seconds, between two sends per KEY and per template. */
	private int minIntervalSeconds = 10;

	/** Cache service holding the per-KEY, per-template counters. */
	private CacheService cacheService;

	/** Prefix of every cache key written by this sender. */
	private String cachePrefix = "msgSender";

	/** Cache expiry time handed to the cache service on every write (default: 24 hours). */
	private int cacheSeconds = 24 * 3600;

	/** Whether a send is rejected when the resolved KEY is blank. */
	private boolean rejectIfKeyBlank = true;

	/** Whether a rejection throws an exception instead of returning a rejection result. */
	private boolean errorIfReject = false;

	/** Whether a send that did not succeed is also counted against the limits. */
	private boolean limitOnFail = false;

	/**
	 * Creates a rate-limiting sender.
	 *
	 * @param tag rate-limit tag placed at the start of the pass/reject log lines
	 */
	public RateLimitMsgSender(String tag) {
		this.tag = tag;
	}

	/**
	 * Returns the name of the wrapped sender.
	 *
	 * @return the delegate's name, or {@code "Unknow"} when no delegate has been set
	 */
	@Override
	public String name() {
		// Read the field once so the null check and the call see the same instance.
		final MsgSender delegate = msgSender;
		if (delegate == null) {
			return "Unknow";
		}
		return delegate.name();
	}

	/**
	 * Sends a message through the wrapped sender unless a limit forbids it.
	 *
	 * <p>Blank KEYs are rejected or passed straight through depending on
	 * {@code rejectIfKeyBlank}; white-listed KEYs are always passed straight through.
	 * Neither of those paths touches the cache.
	 *
	 * @param phone    the target phone number
	 * @param tmplName the message template name
	 * @param tmplArgs the template arguments
	 * @return the delegate's result, or the rejection produced by
	 *         {@link #reject(String, String, String, String)}
	 * @throws IllegalStateException if no {@code msgSender} or no {@code cacheService} is set
	 */
	@Override
	public SendResult send(String phone, String tmplName, String... tmplArgs) {
		if (msgSender == null) {
			throw new IllegalStateException("您必须为 RateLimitMsgSender 设置一个 msgSender");
		}
		if (cacheService == null) {
			throw new IllegalStateException("您必须为 RateLimitMsgSender 设置一个 cacheService");
		}

		final String key = keyResolver.revolve(phone);
		final boolean keyIsBlank = StringUtils.isBlank(key);
		if (keyIsBlank && rejectIfKeyBlank) {
			return reject(phone, tmplName, key, "KEY 为空白");
		}
		// A tolerated blank KEY and a white-listed KEY both bypass the limiter entirely.
		if (keyIsBlank || whiteList.contains(key)) {
			return doSend(key, phone, tmplName, tmplArgs);
		}

		final String cacheKey = cacheKeyOf(key, tmplName);
		final CacheItem cached = cacheService.cache(cacheKey, CacheItem.class);
		final long nowSeconds = System.currentTimeMillis() / 1000;
		if (cached != null) {
			final long sinceLast = nowSeconds - cached.getLastSentTime();
			if (sinceLast < minIntervalSeconds) {
				return reject(phone, tmplName, key, "过于频繁: " + sinceLast);
			}
			// Once the counting period is over, start a fresh one before checking the quota.
			final boolean periodElapsed = cached.getCountTime() < nowSeconds - unitSeconds;
			if (periodElapsed) {
				cached.setCount(0);
				cached.setCountTime(nowSeconds);
			}
			if (cached.getCount() >= maxAllowPerUnit) {
				return reject(phone, tmplName, key, "到达最大条数: " + maxAllowPerUnit);
			}
		}

		final SendResult outcome = doSend(key, phone, tmplName, tmplArgs);
		// Sends that did not succeed are only counted when limitOnFail is enabled.
		if (!outcome.isOk() && !limitOnFail) {
			return outcome;
		}
		recordSend(cacheKey, cached, nowSeconds);
		return outcome;
	}

	/**
	 * Builds the cache key under which the counters of one KEY / template pair are stored.
	 */
	private String cacheKeyOf(String key, String tmplName) {
		return cachePrefix + ":" + key + ":" + tmplName;
	}

	/**
	 * Registers one send in the cache: creates the entry when there is none yet,
	 * stamps the last-sent time, increments the counter and writes the entry back.
	 */
	private void recordSend(String cacheKey, CacheItem existing, long nowSeconds) {
		final CacheItem entry;
		if (existing != null) {
			entry = existing;
		} else {
			entry = new CacheItem();
			entry.setCountTime(nowSeconds);
		}
		entry.setLastSentTime(nowSeconds);
		entry.setCount(entry.getCount() + 1);
		cacheService.cache(cacheKey, cacheSeconds, entry);
	}

	/**
	 * Logs that the message is let through and hands it to the wrapped sender.
	 *
	 * @param key      the resolved KEY (only used for logging)
	 * @param phone    the target phone number
	 * @param tmplName the message template name
	 * @param tmplArgs the template arguments
	 * @return the result returned by the wrapped sender
	 */
	protected SendResult doSend(String key, String phone, String tmplName, String... tmplArgs) {
		log.info("{} 放行 [P: {}, T: {}, K: {}]", tag, phone, tmplName, key);
		return msgSender.send(phone, tmplName, tmplArgs);
	}

	/**
	 * Produces the outcome of a rejected send.
	 *
	 * @param phone    the target phone number
	 * @param tmplName the message template name
	 * @param key      the resolved KEY
	 * @param tip      short reason appended to the rejection message
	 * @return a rejection carrying the composed message, logged at WARN level
	 * @throws RateLimitException instead of returning, when {@code errorIfReject} is enabled
	 */
	protected RateLimitRejection reject(String phone, String tmplName, String key, String tip) {
		final String message = tag + " 拦截 [P: " + phone + ", T: " + tmplName + ", K: " + key + "] " + tip;
		if (!errorIfReject) {
			log.warn(message);
			return new RateLimitRejection(message);
		}
		throw new RateLimitException(message);
	}

	/**
	 * Counter state stored in the cache for one KEY / template pair.
	 * All timestamps are epoch seconds.
	 */
	public static class CacheItem {

		/** Number of sends counted in the current limiting period. */
		private int count;
		/** Time of the most recent counted send. */
		private long lastSentTime;
		/** Start time of the current limiting period. */
		private long countTime;

		/** @return the number of sends counted in the current limiting period */
		public int getCount() {
			return count;
		}

		/** @param count the number of sends counted in the current limiting period */
		public void setCount(int count) {
			this.count = count;
		}

		/** @return the time of the most recent counted send */
		public long getLastSentTime() {
			return lastSentTime;
		}

		/** @param lastSentTime the time of the most recent counted send */
		public void setLastSentTime(long lastSentTime) {
			this.lastSentTime = lastSentTime;
		}

		/** @return the start time of the current limiting period */
		public long getCountTime() {
			return countTime;
		}

		/** @param countTime the start time of the current limiting period */
		public void setCountTime(long countTime) {
			this.countTime = countTime;
		}

	}

	// ---------------------------------------------------------------- configuration accessors

	/** @return the wrapped sender, or {@code null} when none has been set */
	public MsgSender getMsgSender() {
		return msgSender;
	}

	/** @param msgSender the real sender to delegate to */
	public void setMsgSender(MsgSender msgSender) {
		this.msgSender = msgSender;
	}

	/** @return the resolver that derives the limiting KEY from a phone number */
	public KeyResolver getKeyResolver() {
		return keyResolver;
	}

	/**
	 * @param keyResolver the resolver that derives the limiting KEY from a phone number
	 * @throws NullPointerException if {@code keyResolver} is {@code null}
	 */
	public void setKeyResolver(KeyResolver keyResolver) {
		this.keyResolver = Objects.requireNonNull(keyResolver);
	}

	/** @return the KEYs that are exempt from limiting */
	public List<String> getWhiteList() {
		return whiteList;
	}

	/**
	 * @param whiteList the KEYs that are exempt from limiting
	 * @throws NullPointerException if {@code whiteList} is {@code null}
	 */
	public void setWhiteList(List<String> whiteList) {
		this.whiteList = Objects.requireNonNull(whiteList);
	}

	/** @return the length of one limiting period, in seconds */
	public long getUnitSeconds() {
		return unitSeconds;
	}

	/** @param unitSeconds the length of one limiting period, in seconds */
	public void setUnitSeconds(long unitSeconds) {
		this.unitSeconds = unitSeconds;
	}

	/** @return the maximum number of sends per KEY and template within one period */
	public int getMaxAllowPerUnit() {
		return maxAllowPerUnit;
	}

	/** @param maxAllowPerUnit the maximum number of sends per KEY and template within one period */
	public void setMaxAllowPerUnit(int maxAllowPerUnit) {
		this.maxAllowPerUnit = maxAllowPerUnit;
	}

	/** @return the minimum interval between two sends per KEY and template, in seconds */
	public int getMinIntervalSeconds() {
		return minIntervalSeconds;
	}

	/** @param minIntervalSeconds the minimum interval between two sends per KEY and template, in seconds */
	public void setMinIntervalSeconds(int minIntervalSeconds) {
		this.minIntervalSeconds = minIntervalSeconds;
	}

	/** @return the cache service holding the counters, or {@code null} when none has been set */
	public CacheService getCacheService() {
		return cacheService;
	}

	/** @param cacheService the cache service holding the counters */
	public void setCacheService(CacheService cacheService) {
		this.cacheService = cacheService;
	}

	/** @return the prefix of every cache key written by this sender */
	public String getCachePrefix() {
		return cachePrefix;
	}

	/** @param cachePrefix the prefix of every cache key written by this sender */
	public void setCachePrefix(String cachePrefix) {
		this.cachePrefix = cachePrefix;
	}

	/** @return the cache expiry time handed to the cache service on every write */
	public int getCacheSeconds() {
		return cacheSeconds;
	}

	/** @param cacheSeconds the cache expiry time handed to the cache service on every write */
	public void setCacheSeconds(int cacheSeconds) {
		this.cacheSeconds = cacheSeconds;
	}

	/** @return whether a send is rejected when the resolved KEY is blank */
	public boolean isRejectIfKeyBlank() {
		return rejectIfKeyBlank;
	}

	/** @param rejectIfKeyBlank whether a send is rejected when the resolved KEY is blank */
	public void setRejectIfKeyBlank(boolean rejectIfKeyBlank) {
		this.rejectIfKeyBlank = rejectIfKeyBlank;
	}

	/** @return whether a rejection throws instead of returning a rejection result */
	public boolean isErrorIfReject() {
		return errorIfReject;
	}

	/** @param errorIfReject whether a rejection throws instead of returning a rejection result */
	public void setErrorIfReject(boolean errorIfReject) {
		this.errorIfReject = errorIfReject;
	}

	/** @return whether a send that did not succeed is also counted against the limits */
	public boolean isLimitOnFail() {
		return limitOnFail;
	}

	/** @param limitOnFail whether a send that did not succeed is also counted against the limits */
	public void setLimitOnFail(boolean limitOnFail) {
		this.limitOnFail = limitOnFail;
	}

}
